{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6d742657-8e6a-439d-a65a-3c76e73c8810",
   "metadata": {},
   "source": [
    "# Logistic Regression with SPU"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5ddc86c8",
   "metadata": {},
   "source": [
    ">The following code is for demonstration only. It is **NOT for production** because of system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ff9c4ba5-a4a8-46ec-96e3-611ef5e71dc5",
   "metadata": {},
   "source": [
    "[SPU](https://www.secretflow.org.cn/docs/spu/en/) is a domain-specific compiler and runtime suite that provides a provably secure computation service. The SPU compiler uses [XLA](https://www.tensorflow.org/xla) as its front-end IR, which supports a range of AI frameworks (such as TensorFlow, JAX, and PyTorch). The SPU compiler translates XLA into an IR that the SPU runtime can interpret. Currently, the SPU team highly recommends using [JAX](https://github.com/google/jax) as the frontend.\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "After doing this lab, you'll know how to:\n",
    "\n",
    "* Write a Logistic Regression training program with JAX.\n",
    "* Convert a JAX program to an SPU (MPC) program painlessly.\n",
    "\n",
    "In this lab, we use the [Breast Cancer](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(diagnostic)) dataset. The task is to decide whether a tumor is malignant or benign from 30 features. In the MPC program, two parties train the model jointly, and each party provides half of the features (15).\n",
    "\n",
    "First, though, let's set the MPC setting aside and write a Logistic Regression training program directly with JAX."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "629aad7f",
   "metadata": {},
   "source": [
    "## Train a model with JAX"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d9774618-798e-421f-9be0-4fb954d7c710",
   "metadata": {},
   "source": [
    "### Load the Dataset\n",
    "\n",
    "\n",
    "The `breast_cancer` function normalizes the whole dataset and then splits it into train and test subsets.\n",
    "* If `train` is `True`, it returns the train subset. To simulate training on a vertically split dataset, pass `party_id`: party 1 gets the first 15 features without labels, and any other party gets the remaining 15 features together with the labels.\n",
    "* Otherwise, it returns the test subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "364a380e-9cea-42e3-8ab5-06635df97478",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):\n",
    "    x, y = load_breast_cancer(return_X_y=True)\n",
    "    # Min-max scale using the global min and max over all features.\n",
    "    x = (x - np.min(x)) / (np.max(x) - np.min(x))\n",
    "    x_train, x_test, y_train, y_test = train_test_split(\n",
    "        x, y, test_size=0.2, random_state=42\n",
    "    )\n",
    "\n",
    "    if train:\n",
    "        if party_id:\n",
    "            if party_id == 1:\n",
    "                # Party 1 holds the first 15 features and no labels.\n",
    "                return x_train[:, :15], None\n",
    "            else:\n",
    "                return x_train[:, 15:], y_train\n",
    "        else:\n",
    "            return x_train, y_train\n",
    "    else:\n",
    "        return x_test, y_test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fb400b1f-f516-4348-8812-e05a2e8f5e17",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Define the Model\n",
    "\n",
    "First, let's define the loss function, which is a negative log-likelihood in our case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ffb9d32-4150-41cb-a6e9-65cbccf6568a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + jnp.exp(-x))\n",
    "\n",
    "\n",
    "# Outputs probability of a label being true.\n",
    "def predict(W, b, inputs):\n",
    "    return sigmoid(jnp.dot(inputs, W) + b)\n",
    "\n",
    "\n",
    "# Training loss is the negative log-likelihood of the training examples.\n",
    "def loss(W, b, inputs, targets):\n",
    "    preds = predict(W, b, inputs)\n",
    "    label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
    "    return -jnp.mean(jnp.log(label_probs))"
   ]
  },
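  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3f2a9c10",
   "metadata": {},
   "source": [
    "For reference, here is a sketch of the `loss` function above in mathematical notation. The symbols $N$ (number of examples), $x_i$ (feature row $i$), $y_i \\in \\{0, 1\\}$ (its label), and $p_i$ (its predicted probability) are notation introduced for this note and do not appear in the code.\n",
    "\n",
    "$$\n",
    "p_i = \\sigma(x_i^\\top W + b), \\qquad \\sigma(z) = \\frac{1}{1 + e^{-z}}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(W, b) = -\\frac{1}{N} \\sum_{i=1}^{N} \\log\\bigl(p_i\\, y_i + (1 - p_i)(1 - y_i)\\bigr)\n",
    "$$\n",
    "\n",
    "Because each $y_i$ is either 0 or 1, exactly one of the two terms inside the logarithm survives, so this matches the usual binary cross-entropy:\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(W, b) = -\\frac{1}{N} \\sum_{i=1}^{N} \\bigl[\\, y_i \\log p_i + (1 - y_i) \\log(1 - p_i) \\,\\bigr]\n",
    "$$"
   ]
  },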
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7eaaf2b4-6491-471a-85fb-d63236e4a564",
   "metadata": {},
   "source": [
    "Second, let's define a single training step that applies one gradient descent update, computed over all the rows passed in. As a reminder, `x1` holds the 15 features from one party and `x2` holds the other 15 features from the other party."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "71e01788-804f-4378-a268-b84c3940320a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import grad\n",
    "\n",
    "\n",
    "def train_step(W, b, x1, x2, y, learning_rate):\n",
    "    x = jnp.concatenate([x1, x2], axis=1)\n",
    "    Wb_grad = grad(loss, (0, 1))(W, b, x, y)\n",
    "    W -= learning_rate * Wb_grad[0]\n",
    "    b -= learning_rate * Wb_grad[1]\n",
    "    return W, b"
   ]
  },
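  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7b1d4e52",
   "metadata": {},
   "source": [
    "The update performed by `train_step` can be sketched in equations. Here $\\mathcal{L}$ is the value returned by `loss`, $X$ is the concatenation of `x1` and `x2`, $p = \\sigma(XW + b)$ is the vector of predictions, $\\eta$ is `learning_rate`, and $N$ is the number of rows; these symbols are notation introduced for this note. In the code, `grad` derives the gradients automatically. The closed forms below assume labels in $\\{0, 1\\}$ and are given only as a cross-check.\n",
    "\n",
    "$$\n",
    "\\nabla_W \\mathcal{L} = \\frac{1}{N} X^\\top (p - y), \\qquad \\frac{\\partial \\mathcal{L}}{\\partial b} = \\frac{1}{N} \\sum_{i=1}^{N} (p_i - y_i)\n",
    "$$\n",
    "\n",
    "$$\n",
    "W \\leftarrow W - \\eta\\, \\nabla_W \\mathcal{L}, \\qquad b \\leftarrow b - \\eta\\, \\frac{\\partial \\mathcal{L}}{\\partial b}\n",
    "$$"
   ]
  },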
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "eac6baac-280c-44b4-b8a0-9a6eb3a76908",
   "metadata": {
    "tags": []
   },
   "source": [
    "\n",
    "Last, let's put everything together in a `fit` function, which runs the training step once per epoch and returns the trained model parameters `W` and `b`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f15b0c71-5761-41aa-b70d-9ba242f8a795",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):\n",
    "    for _ in range(epochs):\n",
    "        W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)\n",
    "    return W, b"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5dd8abd5",
   "metadata": {},
   "source": [
    "### Validate the Model\n",
    "\n",
    "We use the AUC (area under the ROC curve) to validate the binary classification model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "132fcee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "\n",
    "def validate_model(W, b, X_test, y_test):\n",
    "    y_pred = predict(W, b, X_test)\n",
    "    return roc_auc_score(y_test, y_pred)"
   ]
  },
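  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c84e0a6f",
   "metadata": {},
   "source": [
    "As background, the AUC has a ranking interpretation: it is the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. In the sketch below, $\\mathcal{P}$ and $\\mathcal{Q}$ denote the sets of positive and negative test examples and $s_i$ is the predicted score of example $i$; these symbols are notation introduced for this note.\n",
    "\n",
    "$$\n",
    "\\mathrm{AUC} = \\frac{1}{\\lvert \\mathcal{P} \\rvert\\, \\lvert \\mathcal{Q} \\rvert} \\sum_{i \\in \\mathcal{P}} \\sum_{j \\in \\mathcal{Q}} \\Bigl( \\mathbf{1}[s_i > s_j] + \\frac{1}{2}\\, \\mathbf{1}[s_i = s_j] \\Bigr)\n",
    "$$\n",
    "\n",
    "An AUC of 1 means every positive example is ranked above every negative one, while 0.5 corresponds to a random ranking."
   ]
  },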
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "581fac73",
   "metadata": {},
   "source": [
    "### Have a Try!\n",
    "\n",
    "Let's put everything together and train an LR model!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ad002d29",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc=0.9878807730101539\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Load the data\n",
    "x1, _ = breast_cancer(party_id=1, train=True)\n",
    "x2, y = breast_cancer(party_id=2, train=True)\n",
    "\n",
    "# Initial parameters and hyperparameters\n",
    "W = jnp.zeros((30,))\n",
    "b = 0.0\n",
    "epochs = 10\n",
    "learning_rate = 1e-2\n",
    "\n",
    "# Train the model\n",
    "W, b = fit(W, b, x1, x2, y, epochs=epochs, learning_rate=learning_rate)\n",
    "\n",
    "# Validate the model\n",
    "X_test, y_test = breast_cancer(train=False)\n",
    "auc = validate_model(W, b, X_test, y_test)\n",
    "print(f'auc={auc}')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "cb11163f",
   "metadata": {},
   "source": [
    "Keep this AUC in mind, since we are going to run the same training with SPU next!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e5c4d4b1",
   "metadata": {},
   "source": [
    "## Train a Model with SPU\n",
    "\n",
    "In this part, we show how to run the same training securely with MPC!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8fe0ded7-d861-4506-9b87-4342629a8c5e",
   "metadata": {},
   "source": [
    "### Init the Environment\n",
    "\n",
    "We are going to initialize three virtual devices in our physical environment.\n",
    "- alice, bob: two PYU devices for local plaintext computation.\n",
    "- spu: an SPU device made up of alice and bob for secure MPC computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14c4f276-7eb0-4dd9-a39b-06b588bcc5bf",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Run secretflow in simulation mode.\n",
      "2023-03-14 12:59:36,141\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob'], address='local')\n",
    "\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "cecb025b",
   "metadata": {},
   "source": [
    "### Load the Dataset\n",
    "\n",
    "We instruct alice and bob to each load their own part of the train subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "149864cc-ea35-4f3e-bb7a-2247ef835ea8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<secretflow.device.device.pyu.PYUObject at 0x7f74b6586a30>,\n",
       " <secretflow.device.device.pyu.PYUObject at 0x7f74b58fd6d0>,\n",
       " <secretflow.device.device.pyu.PYUObject at 0x7f74b58fd8e0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1, _ = alice(breast_cancer)(party_id=1)\n",
    "x2, y = bob(breast_cancer)(party_id=2)\n",
    "\n",
    "x1, x2, y"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "019e3f74-cfa8-458b-b3ca-6d3431d0e64d",
   "metadata": {},
   "source": [
    "Before training, we need to pass the initial model parameters and all data to the SPU device. SecretFlow provides two methods:\n",
    "- `secretflow.to`: transfers a Python object or DeviceObject to a specific device.\n",
    "- `DeviceObject.to`: transfers the DeviceObject to a specific device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a899ee65-a963-4f75-b5bf-e97ed885b52f",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = spu\n",
    "\n",
    "W = jnp.zeros((30,))\n",
    "b = 0.0\n",
    "\n",
    "W_, b_, x1_, x2_, y_ = (\n",
    "    sf.to(alice, W).to(device),\n",
    "    sf.to(alice, b).to(device),\n",
    "    x1.to(device),\n",
    "    x2.to(device),\n",
    "    y.to(device),\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "71fabe33-8a83-4cd7-97ef-66743ce4d492",
   "metadata": {},
   "source": [
    "### Train the Model\n",
    "\n",
    "Now we are ready to train an LR model with SPU. After training, the model parameters `W_` and `b_` are SPUObjects, which are still secret."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d2b3a1a2-f12d-4fe7-bea5-5a6e4faf01ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<secretflow.device.device.spu.SPUObject at 0x7f7608ad88e0>,\n",
       " <secretflow.device.device.spu.SPUObject at 0x7f7608ad83a0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(SPURuntime pid=1525949)\u001b[0m 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "\u001b[2m\u001b[36m(SPURuntime pid=1525959)\u001b[0m 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n"
     ]
    }
   ],
   "source": [
    "W_, b_ = device(\n",
    "    fit,\n",
    "    static_argnames=['epochs'],\n",
    "    num_returns_policy=sf.device.SPUCompilerNumReturnsPolicy.FROM_USER,\n",
    "    user_specified_num_returns=2,\n",
    ")(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)\n",
    "\n",
    "W_, b_"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ea3a5d7b-f16d-4330-a96f-5a4f6b0a2996",
   "metadata": {},
   "source": [
    "### Reveal the Result\n",
    "\n",
    "To check the trained model, we need to convert the SPUObject (secret) to a Python object (plaintext). SecretFlow provides `sf.reveal` to convert any DeviceObject to a Python object.\n",
    "\n",
    "> Be careful with `sf.reveal`, since it may leak secrets."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a58ff0dc-e2ac-4dfe-b562-4f3e42059c02",
   "metadata": {},
   "source": [
    "Finally, let's validate the model with AUC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9bdffb27-4a44-4ee6-ab00-30d1b906fb38",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc=0.987880773010154\n"
     ]
    }
   ],
   "source": [
    "auc = validate_model(sf.reveal(W_), sf.reveal(b_), X_test, y_test)\n",
    "print(f'auc={auc}')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ac49ff9c",
   "metadata": {},
   "source": [
    "You should find that the model from the SPU training program achieves essentially the same AUC as the JAX program.\n",
    "\n",
    "This is the end of the lab."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "de4c945f5346493decaa0ea82289843a7da2415616b96b9f4b104111cc0c19ad"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
